{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "3bca1827-a930-4a72-b819-394544442e52",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\mikeb\\.conda\\envs\\sandbox\\Lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "from anthropic import Anthropic\n",
    "from sklearn.metrics import matthews_corrcoef, accuracy_score, f1_score\n",
    "from datasets import load_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "04c00ff8-459a-4789-9e93-637d1f02820c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the evaluation frame. Later cells read its 'premise', 'augmented_hypothesis',\n",
    "# 'entailment' and 'task' columns, and pass 'base_polnli', 'large_polnli' and 'llama'\n",
    "# to metrics() as prediction columns.\n",
    "test = pd.read_csv('../data/polnli_test_results.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "a03845b8-ebf0-4fe3-97d1-b6f7479a83dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "def metrics(df, preds, group_by=None):\n",
    "    \"\"\"Score one or more prediction columns against the 'entailment' labels.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    df : pandas.DataFrame\n",
    "        Must contain the 'entailment' column, every column named in ``preds``\n",
    "        and, for grouped output, the ``group_by`` column.\n",
    "    preds : list of str\n",
    "        Names of the prediction columns to evaluate.\n",
    "    group_by : {'dataset', 'task'} or None, default None\n",
    "        Column to break the scores down by. Any other value is ignored and\n",
    "        the scores are computed over the whole frame.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    pandas.DataFrame\n",
    "        MCC, accuracy and weighted F1, indexed by 'Column', or by\n",
    "        ('Column', group_by.capitalize()) when grouped. An empty ``preds``\n",
    "        gives an empty frame with that same layout.\n",
    "    \"\"\"\n",
    "    true_col = 'entailment'\n",
    "    score_cols = ['MCC', 'Accuracy', 'F1']\n",
    "\n",
    "    def get_metrics(y_true, y_pred):\n",
    "        return {\n",
    "            'MCC': matthews_corrcoef(y_true, y_pred),\n",
    "            'Accuracy': accuracy_score(y_true, y_pred),\n",
    "            'F1': f1_score(y_true, y_pred, average='weighted')\n",
    "        }\n",
    "\n",
    "    # Decide once whether the output is grouped so that row building and\n",
    "    # indexing cannot disagree.\n",
    "    grouped = group_by in ['dataset', 'task']\n",
    "    index_cols = ['Column', group_by.capitalize()] if grouped else ['Column']\n",
    "\n",
    "    rows = []\n",
    "    for col in preds:\n",
    "        if grouped:\n",
    "            for group_name, group in df.groupby(group_by):\n",
    "                row = get_metrics(group[true_col], group[col])\n",
    "                row['Column'] = col\n",
    "                row[index_cols[1]] = group_name\n",
    "                rows.append(row)\n",
    "        else:\n",
    "            row = get_metrics(df[true_col], df[col])\n",
    "            row['Column'] = col\n",
    "            rows.append(row)\n",
    "\n",
    "    # Naming the columns explicitly keeps set_index working when rows is empty.\n",
    "    results_df = pd.DataFrame(rows, columns=score_cols + index_cols)\n",
    "    return results_df.set_index(index_cols if grouped else 'Column')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "cfe9add7-2f0c-4a59-8911-71cecf397faf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# With no api_key argument the client reads the ANTHROPIC_API_KEY environment\n",
    "# variable, so no secret is stored in the notebook.\n",
    "claude = Anthropic()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "1f331ef4-5573-4222-ace5-b83abd3b7518",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prompt template; the {doc} and {hypothesis} placeholders are filled in with str.format().\n",
    "# Label convention requested from the model: 0 = the hypothesis is true, 1 = it is not.\n",
    "user_message = \"\"\"You are a classifier that can only respond with 0 or 1. I'm going to show you a short text sample and I want you to determine if {hypothesis}. Here is the text:\n",
    "{doc}\n",
    "\n",
    "If it is true that {hypothesis}, return 0. If it is not true that {hypothesis}, return 1.\n",
    "Do not explain your answer, and only return 0 or 1.\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "81b7896a-09a3-47f2-a2a0-744de01e14ad",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: total: 35.2 s\n",
      "Wall time: 4h 8min 42s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "model = \"claude-3-5-sonnet-20240620\"\n",
    "data = test\n",
    "labels = []\n",
    "\n",
    "# Walk the two text columns in row order; one API request per row.\n",
    "for doc, hypothesis in zip(data['premise'], data['augmented_hypothesis']):\n",
    "    res = claude.messages.create(\n",
    "        max_tokens=2,\n",
    "        messages=[\n",
    "            {\n",
    "                \"role\": \"user\",\n",
    "                \"content\": user_message.format(doc=doc, hypothesis=hypothesis),\n",
    "            }\n",
    "        ],\n",
    "        model=model,\n",
    "        temperature=0,\n",
    "    )\n",
    "    # Strip whitespace so a reply such as ' 1' still matches the label mapping\n",
    "    # applied afterwards; a reply with no content block is recorded as ''.\n",
    "    labels.append(res.content[0].text.strip() if res.content else '')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "680ee7a4-04b0-47bd-a97d-46277f9a5c18",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\mikeb\\AppData\\Local\\Temp\\ipykernel_12576\\1404300228.py:2: FutureWarning: Downcasting behavior in `replace` is deprecated and will be removed in a future version. To retain the old behavior, explicitly call `result.infer_objects(copy=False)`. To opt-in to the future behavior, set `pd.set_option('future.no_silent_downcasting', True)`\n",
      "  test['sonnet'] = test['sonnet'].replace({'1':1, '0':0, '2':-1})\n"
     ]
    }
   ],
   "source": [
    "label_map = {'1': 1, '0': 0, '2': -1}\n",
    "# An explicit lookup avoids the silent downcasting that Series.replace is\n",
    "# deprecating (the FutureWarning recorded in this cell's output). Replies that\n",
    "# are not in the mapping are kept unchanged, as before.\n",
    "test['sonnet'] = [label_map.get(str(raw).strip(), raw) for raw in labels]\n",
    "# NOTE(review): the frame was read from '../data/polnli_test_results.csv' but is\n",
    "# written to the working directory here - confirm which location is intended.\n",
    "test.to_csv('polnli_test_results.csv', index = False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "aaf114c3-aa2c-4232-9332-0e34804de5c7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>MCC</th>\n",
       "      <th>Accuracy</th>\n",
       "      <th>F1</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Column</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>base_polnli</th>\n",
       "      <td>0.894269</td>\n",
       "      <td>0.948978</td>\n",
       "      <td>0.948852</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>large_polnli</th>\n",
       "      <td>0.915911</td>\n",
       "      <td>0.959326</td>\n",
       "      <td>0.959180</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>llama</th>\n",
       "      <td>0.730997</td>\n",
       "      <td>0.862358</td>\n",
       "      <td>0.863467</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>sonnet</th>\n",
       "      <td>0.815902</td>\n",
       "      <td>0.910517</td>\n",
       "      <td>0.909423</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                   MCC  Accuracy        F1\n",
       "Column                                    \n",
       "base_polnli   0.894269  0.948978  0.948852\n",
       "large_polnli  0.915911  0.959326  0.959180\n",
       "llama         0.730997  0.862358  0.863467\n",
       "sonnet        0.815902  0.910517  0.909423"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Overall scores for every model, computed over the whole test set.\n",
    "metrics(test, ['base_polnli', 'large_polnli', 'llama', 'sonnet'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "a13493d1-8ae4-41d9-8be6-55f0de2759a7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th>MCC</th>\n",
       "      <th>Accuracy</th>\n",
       "      <th>F1</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Column</th>\n",
       "      <th>Task</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th rowspan=\"4\" valign=\"top\">base_polnli</th>\n",
       "      <th>event extraction</th>\n",
       "      <td>0.813742</td>\n",
       "      <td>0.906774</td>\n",
       "      <td>0.907042</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>hatespeech and toxicity</th>\n",
       "      <td>0.841410</td>\n",
       "      <td>0.946036</td>\n",
       "      <td>0.945248</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>stance detection</th>\n",
       "      <td>0.920600</td>\n",
       "      <td>0.961546</td>\n",
       "      <td>0.961488</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>topic classification</th>\n",
       "      <td>0.926241</td>\n",
       "      <td>0.963834</td>\n",
       "      <td>0.963826</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"4\" valign=\"top\">large_polnli</th>\n",
       "      <th>event extraction</th>\n",
       "      <td>0.819699</td>\n",
       "      <td>0.909567</td>\n",
       "      <td>0.909838</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>hatespeech and toxicity</th>\n",
       "      <td>0.881535</td>\n",
       "      <td>0.959360</td>\n",
       "      <td>0.959026</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>stance detection</th>\n",
       "      <td>0.969009</td>\n",
       "      <td>0.984979</td>\n",
       "      <td>0.984972</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>topic classification</th>\n",
       "      <td>0.924496</td>\n",
       "      <td>0.962503</td>\n",
       "      <td>0.962322</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"4\" valign=\"top\">llama</th>\n",
       "      <th>event extraction</th>\n",
       "      <td>0.808244</td>\n",
       "      <td>0.905726</td>\n",
       "      <td>0.905480</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>hatespeech and toxicity</th>\n",
       "      <td>0.559060</td>\n",
       "      <td>0.782145</td>\n",
       "      <td>0.799067</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>stance detection</th>\n",
       "      <td>0.605609</td>\n",
       "      <td>0.798117</td>\n",
       "      <td>0.799651</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>topic classification</th>\n",
       "      <td>0.918734</td>\n",
       "      <td>0.959396</td>\n",
       "      <td>0.959505</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"4\" valign=\"top\">sonnet</th>\n",
       "      <th>event extraction</th>\n",
       "      <td>0.784838</td>\n",
       "      <td>0.880936</td>\n",
       "      <td>0.881063</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>hatespeech and toxicity</th>\n",
       "      <td>0.571282</td>\n",
       "      <td>0.862425</td>\n",
       "      <td>0.853626</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>stance detection</th>\n",
       "      <td>0.791883</td>\n",
       "      <td>0.899259</td>\n",
       "      <td>0.898415</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>topic classification</th>\n",
       "      <td>0.946775</td>\n",
       "      <td>0.973819</td>\n",
       "      <td>0.973765</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                           MCC  Accuracy        F1\n",
       "Column       Task                                                 \n",
       "base_polnli  event extraction         0.813742  0.906774  0.907042\n",
       "             hatespeech and toxicity  0.841410  0.946036  0.945248\n",
       "             stance detection         0.920600  0.961546  0.961488\n",
       "             topic classification     0.926241  0.963834  0.963826\n",
       "large_polnli event extraction         0.819699  0.909567  0.909838\n",
       "             hatespeech and toxicity  0.881535  0.959360  0.959026\n",
       "             stance detection         0.969009  0.984979  0.984972\n",
       "             topic classification     0.924496  0.962503  0.962322\n",
       "llama        event extraction         0.808244  0.905726  0.905480\n",
       "             hatespeech and toxicity  0.559060  0.782145  0.799067\n",
       "             stance detection         0.605609  0.798117  0.799651\n",
       "             topic classification     0.918734  0.959396  0.959505\n",
       "sonnet       event extraction         0.784838  0.880936  0.881063\n",
       "             hatespeech and toxicity  0.571282  0.862425  0.853626\n",
       "             stance detection         0.791883  0.899259  0.898415\n",
       "             topic classification     0.946775  0.973819  0.973765"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Same scores, broken down by the 'task' column.\n",
    "metrics(test, ['base_polnli', 'large_polnli', 'llama', 'sonnet'], group_by='task')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
